#!/usr/bin/python
# Copyright 2020 Makani Technologies LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Generate the switch_info.c file."""

import os
import sys
import textwrap

from makani.avionics.common import network_config
from makani.avionics.network import aio_node
from makani.avionics.network import eop_message_type
from makani.avionics.network import message_type
from makani.avionics.network import network_config as net_config
from makani.avionics.network import network_util
from makani.avionics.network import winch_message_type
from makani.lib.python import c_helpers


class IpAddressException(Exception):
  """Raised when an AIO node's IP octet cannot be used to build a VLAN ID."""


class SwitchMismatchException(Exception):
  """Raised when a switch's TMS570 port is not connected to an AIO node."""


class TrunkException(Exception):
  """Raised when a switch's trunk configuration is invalid."""


def _WriteSwitchInfo(autogen_root, output_dir, script_name, config):
  """Write switch_info.c.

  For every switch that has a TMS570 host port configured, this emits a
  segment VLAN table, an optional trunk multicast override table, and a
  SwitchInfo struct.  It then emits a GetSwitchInfo() function that maps each
  such switch's AioNode to its SwitchInfo; all other AioNode values fall
  through to an assert(false) / return NULL.

  Args:
    autogen_root: The MAKANI_HOME-equivalent top directory.
    output_dir: The directory in which to output the files.
    script_name: This script's filename.
    config: A NetworkConfig.

  Raises:
    IpAddressException: On any IP address error.
    SwitchMismatchException: If the TMS570 port of a switch is invalid.
    TrunkException: If there is a problem with the trunk configuration.
  """
  file_without_extension = os.path.join(output_dir, 'switch_info')
  rel_path = os.path.relpath(file_without_extension, autogen_root)

  parts = [textwrap.dedent("""
      // Generated by {name}; do not edit.

      #include "{header}"

      #include <assert.h>
      #include <stdbool.h>
      #include <stdint.h>
      #include <string.h>

      """.format(name=script_name, header=rel_path + '.h'))]

  ip_addr = {n.snake_name: n.ip_octet for n in config.aio_nodes}
  switches = config.GetSwitches()
  # Iterating a dict yields its keys on both Python 2 and Python 3.
  switches_sorted = sorted(switches)
  message_types = config.all_messages

  def _HasTms570(switch):
    """Returns True if the switch config names a TMS570 host port."""
    return 'config' in switch and 'tms570' in switch['config']

  def _GetTms570Node(switch):
    """Returns the snake-case AIO node on the TMS570 port, or None."""
    if not _HasTms570(switch):
      return None
    port = switch['ports'][switch['config']['tms570']]
    if not port.startswith('aio_nodes.'):
      raise SwitchMismatchException('No AIO node connected to TMS570 port.')
    return port[len('aio_nodes.'):]

  # The VLAN numbering convention here uses the lowest of the last octet of the
  # IP addresses of the TMS570s on the two connected switches, concatenated
  # with the port number on the associated switch.  Since the core switches have
  # 27 ports we need 5 bits to represent the port number.  As a consequence
  # TMS570s with last octets greater than 127 will present a problem for this
  # scheme.  Additionally the IP with last octet 0 is reserved for other VLANs.
  def _GetSegId(node, port, remote_node, remote_port):
    """Returns the segment VLAN ID for the link between two switches."""
    ip = ip_addr[node]
    remote_ip = ip_addr[remote_node]
    if remote_ip < ip:
      ip = remote_ip
      port = remote_port
    if ip <= 0 or ip > 127:
      raise IpAddressException('Invalid IP octet 0 >= IP or IP > 127.')
    return (ip << 5) | int(port)

  def _GenerateMask(port_list):
    """Returns a bitmask with one bit set per port number in port_list."""
    return sum(1 << i for i in port_list)

  def _GetOverrideMac(type_name):
    """Returns the multicast Ethernet address of the named message type."""
    for m in message_types:
      if m.name != type_name:
        continue
      if m.eop_message:
        return network_config.EopMessageTypeToEthernetAddress(
            eop_message_type.__dict__['kEopMessageType' + type_name])
      if m.winch_message:
        return network_config.WinchMessageTypeToEthernetAddress(
            winch_message_type.__dict__['kWinchMessageType' + type_name])
      return network_config.AioMessageTypeToEthernetAddress(
          message_type.__dict__['kMessageType' + type_name])
    raise TrunkException('Override contains invalid message type %s.'
                         % type_name)

  path_finder = network_util.PathFinder(switches, None, network_c=True)
  net_c_forward = network_util.MakeNetworkCForwardingMap(path_finder)

  for switch_name in switches_sorted:
    switch = switches[switch_name]
    # We use simple configs for virtual or managed switches.
    if not _HasTms570(switch):
      continue
    # Named switch_config so the NetworkConfig parameter is not shadowed.
    switch_config = switch['config']
    node = _GetTms570Node(switch)
    chip = switch_config['chip']

    # One VLAN ID per port; ports that do not lead to another switch with a
    # known TMS570 node keep the value 0.
    seg_ids = [0] * chip['num_ports']
    for port, remote in switch['ports'].items():
      if not remote.startswith('switches.'):
        continue
      _, remote_switch_name, remote_port = remote.split('.')
      remote_node = _GetTms570Node(switches[remote_switch_name])
      if not remote_node or remote_node not in ip_addr:
        continue
      seg_ids[int(port)] = _GetSegId(node, port, remote_node, remote_port)

    node_camel = c_helpers.SnakeToCamel(node)
    info = {
        'node_camel': node_camel,
        'switch_type': 'kSwitchType' + c_helpers.SnakeToCamel(chip['type']),
        'segment_vlans': 'segmentVlanIds' + node_camel,
        'num_ports': chip['num_ports'],
        'num_fiber_ports': chip['num_fiber_ports'],
    }
    parts.append(textwrap.fill('static const uint16_t {}[{}] = {{{}}};'.format(
        info['segment_vlans'], info['num_ports'],
        ', '.join(str(i) for i in seg_ids)), 80, subsequent_indent='  '))

    # Trunk configuration.
    trunk_ports = set()
    select_ports = set()
    trunk_unicast_learning_ports = set()
    info['num_multicast_overrides'] = 0
    info['multicast_overrides'] = 'NULL'
    if 'trunk' in switch_config:
      trunk = switch_config['trunk']
      trunk_ports = trunk['ports']
      if 'unicast_learning' in trunk:
        trunk_unicast_learning_ports = trunk['unicast_learning']
      if 'override_message_routes' in trunk:
        mcast_overrides = []
        for override_type, override_ports in (
            trunk['override_message_routes'].items()):
          mask = _GenerateMask(override_ports)
          mac = _GetOverrideMac(override_type)
          mac_init = '{%s}' % ', '.join(
              '0x%02X' % x
              for x in [mac.a, mac.b, mac.c, mac.d, mac.e, mac.f])
          mcast_overrides.append((mac_init, mask))
        info['multicast_overrides'] = 'multicastOverrides' + node_camel
        info['num_multicast_overrides'] = len(mcast_overrides)
        parts.append(
            'static const TrunkMulticastOverride {}[{}] = {{\n{}\n}};'.format(
                info['multicast_overrides'], info['num_multicast_overrides'],
                ',\n'.join('  {%s, 0x%08X}' % o for o in mcast_overrides)))
      # The default-select ports fall back to the trunk ports when unspecified.
      if 'select_default_ports' in trunk:
        select_ports = trunk['select_default_ports']
      else:
        select_ports = trunk_ports

    info['forward_mask_a'] = _GenerateMask(switch_config['network_a'])
    info['forward_mask_b'] = _GenerateMask(switch_config['network_b'])
    info['forward_mask_c'] = net_c_forward[switch_name]
    info['egress_mask_c'] = _GenerateMask(switch_config.get('network_c', []))
    if 'isolate' in switch_config:
      if trunk_ports:
        raise TrunkException('Isolate and trunk definitions are currently not '
                             'supported together.')
      isolate_ports = switch_config['isolate']
    else:
      # Without an explicit isolate list, the trunk ports are isolated.
      isolate_ports = trunk_ports
    info['isolate_mask'] = _GenerateMask(isolate_ports)
    info['trunk_mask'] = _GenerateMask(trunk_ports)
    info['unicast_learning_mask'] = _GenerateMask(trunk_unicast_learning_ports)
    info['select_default_mask'] = _GenerateMask(select_ports)
    info['unicast_mask'] = _GenerateMask(switch_config['unicast'])
    info['host_port'] = switch_config['tms570']
    info['mirror_port'] = switch_config['mirror']

    parts.append(
        'static const SwitchInfo switchInfo{node_camel} = {{\n'
        '  {switch_type}, {num_ports}, {num_fiber_ports},\n'
        '  {host_port}, {mirror_port},\n'
        '  0x{forward_mask_a:X}, 0x{forward_mask_b:X},\n'
        '  0x{forward_mask_c:X}, 0x{egress_mask_c:X},\n'
        '  0x{isolate_mask:X}, 0x{unicast_mask:X}, {segment_vlans},\n'
        '  {{0x{trunk_mask:X}, 0x{select_default_mask:X},'
        ' 0x{unicast_learning_mask:X},\n'
        '   {num_multicast_overrides}, {multicast_overrides}}}\n'
        '}};\n'.format(**info))

  parts.append(textwrap.dedent("""
      const SwitchInfo *GetSwitchInfo(AioNode node) {
        switch (node) {"""[1:]))

  # Start with every AioNode name and remove those that get their own case, so
  # the remainder can be emitted as fall-through cases below.
  unused_nodes = set(c_helpers.EnumHelper('AioNode', aio_node).Names())

  for switch_name in switches_sorted:
    switch = switches[switch_name]
    if not _HasTms570(switch):
      continue
    node_camel = c_helpers.SnakeToCamel(_GetTms570Node(switch))
    node_name = 'kAioNode' + node_camel
    unused_nodes.remove(node_name)
    parts.append('    case %s:' % node_name)
    # Every switch reaching this point has a config (checked above), so a
    # SwitchInfo struct was emitted for it by the first loop.
    parts.append('      return &switchInfo%s;' % node_camel)

  parts.append('    // Fall-through intentional.')
  for node_name in sorted(unused_nodes):
    parts.append('    case %s:' % node_name)
  parts.append(textwrap.dedent("""
          case kAioNodeForceSigned:
          case kNumAioNodes:
          default:
            assert(false);
            return NULL;
        }
      }
  """[1:]))

  with open(file_without_extension + '.c', 'w') as f:
    f.write('\n'.join(parts))


def main(argv):
  """Parse the generation flags and write switch_info.c.

  Args:
    argv: The command-line arguments, as in sys.argv.
  """
  flags, remaining_argv = net_config.ParseGenerationFlags(argv)

  network = net_config.NetworkConfig(flags.network_file)
  # The generated file's header comment names this script.
  this_script = os.path.basename(remaining_argv[0])
  _WriteSwitchInfo(
      flags.autogen_root, flags.output_dir, this_script, network)

# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
  main(sys.argv)
